# Copyright (C) 2025 Zensar Technologies Private Ltd.
# SPDX-License-Identifier: Apache-2.0

import os

import requests
from langchain_community.llms import VLLMOpenAI

from comps import CustomLogger, OpeaComponentRegistry
from comps.cores.proto.api_protocol import ArbPostHearingAssistantChatCompletionRequest

from .common import *

# Module-level logger for this component.
logger = CustomLogger("arb_post_hearing_assistant_vllm")
# os.getenv returns a string whenever the variable is set, so values such as
# "false" or "0" would otherwise be truthy. Parse the value explicitly: unset,
# empty, "0", "false", "no" and "off" (case-insensitive) disable the flag; any
# other value enables it, as before.
logflag = os.getenv("LOGFLAG", "").strip().lower() not in ("", "0", "false", "no", "off")
# Base URL of the vLLM server, overridable through the LLM_ENDPOINT variable.
LLM_ENDPOINT = os.getenv("LLM_ENDPOINT", "http://vllm-server:80")
# Model identifier passed to the vLLM client, overridable through LLM_MODEL_ID.
MODEL_NAME = os.getenv("LLM_MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct")


@OpeaComponentRegistry.register("OpeaArbPostHearingAssistantVllm")
class OpeaArbPostHearingAssistantVllm(OpeaArbPostHearingAssistant):
    """An OPEA arbitration post-hearing assistant component backed by a vLLM service.

    Derived from OpeaArbPostHearingAssistant; it talks to the vLLM service through
    the LangChain VLLMOpenAI client.

    Attributes:
        client (VLLMOpenAI): The LangChain client built by the most recent call to
            invoke(); it is re-created on every request.
    """

    def check_health(self) -> bool:
        """Checks the health of the vLLM LLM service.

        Sends a GET request to the ``/health`` path of ``self.llm_endpoint``
        (presumably set by the base class -- not visible in this file).

        Returns:
            bool: True if the service answered with HTTP 200, False on any other
                status code or if the request raised an exception.
        """
        try:
            # Bound the probe: without a timeout, requests waits indefinitely on
            # an endpoint that accepts the connection but never answers.
            response = requests.get(f"{self.llm_endpoint}/health", timeout=10)
        except Exception as e:
            # Best-effort probe: any failure is logged and reported as unhealthy.
            logger.error(e)
            logger.error("Health check failed")
            return False

        return response.status_code == 200

    async def invoke(self, input: ArbPostHearingAssistantChatCompletionRequest):
        """Invokes the vLLM LLM service to generate the ArbPostHearing summary output for the provided input.

        A new VLLMOpenAI client is built for each call from the request's sampling
        parameters. Falsy values (None or 0) for ``max_tokens``, ``top_p`` and
        ``temperature`` fall back to 1024, 0.95 and 0.01 respectively. The request
        timeout is only set when ``input.timeout`` is not None.

        Args:
            input (ArbPostHearingAssistantChatCompletionRequest): The input text(s)
                together with the generation parameters.

        Returns:
            Whatever ``self.generate`` returns for this input and client
            (presumably defined on the base class -- not visible in this file).
        """
        request_timeout = float(input.timeout) if input.timeout is not None else None

        self.client = VLLMOpenAI(
            # vLLM's OpenAI-compatible server is addressed without a real API key.
            openai_api_key="EMPTY",
            openai_api_base=self.llm_endpoint + "/v1",
            model_name=MODEL_NAME,
            default_headers={},
            max_tokens=input.max_tokens or 1024,
            top_p=input.top_p or 0.95,
            temperature=input.temperature or 0.01,
            request_timeout=request_timeout,
        )

        return await self.generate(input, self.client)
